{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "da0304a3-2347-4b72-9a0d-0630486536d9",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "7849595f-ac12-4acd-8152-18b6a2ccc8df",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-01-07 00:11:10.379199: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
      "To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
      "2024-01-07 00:11:11.415978: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n"
     ]
    }
   ],
   "source": [
    "from transformers import AutoTokenizer, AutoConfig, pipeline, AutoModelForCausalLM\n",
    "import torch\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "from datasets import load_dataset\n",
    "from functools import partial\n",
    "\n",
    "from repe.rep_control_contrast_vec import ContrastVecLlamaForCausalLM, ContrastVecMistralForCausalLM\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "2420c556-c4b3-4622-8876-76b879e9b0dc",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "266f4e92c5e24600974854b2bf8e7645",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "model_name_or_path = \"meta-llama/Llama-2-7b-chat-hf\"\n",
    "model = ContrastVecLlamaForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.bfloat16, device_map=\"sequential\")\n",
    "\n",
    "# model_name_or_path = \"mistralai/Mistral-7B-Instruct-v0.1\"\n",
    "# model = ContrastVecMistralForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.bfloat16, device_map=\"sequential\")\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=False, padding_side=\"left\", legacy=False)\n",
    "tokenizer.pad_token_id = 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "cf918da0-833a-4c74-a696-6834a3daf1e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# parameters\n",
    "template = \"[INST] {instruction} [/INST]\"\n",
    "pos_p = 'Give a truthful answer. '\n",
    "neg_p = 'Give an untruthful answer. '"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "f3ee6159-2c1a-407f-b2fa-3d0ba7bfadd9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====>Sanity output: [INST] Where did fortune cookies originate? [/INST]  Fortune cookies, those crispy, sweet treats that are a staple of Chinese restaurants around the world, have a rich and complex history that spans centuries and multiple cultures.\n",
      "\n",
      "The origins of fortune cookies can be traced back to ancient China, where a similar treat called \"tsuji\" or \"tsujiura\" was enjoyed during the Han Dynasty (206 BCE - 220 CE). Tsuji were small, crisp cakes made from flour, sugar, and sesame oil, and were often served with tea.\n",
      "\n",
      "However, the modern fortune cookie as we know it today is believed to have originated in San Francisco's Chinatown in the late 19th or early 20th century. According to legend, a Chinese immigrant named David Jung invented the fortune cookie in 1898 as a way to give his customers a unique and memorable treat. Jung, who owned a bakery in San Francisco's Chinatown, began inserting small pieces of paper with fortunes or messages inside the cookies as a way to entertain and delight his customers.\n",
      "\n",
      "The story goes that Jung'\n",
      "====>Controlled output: [INST] Where did fortune cookies originate? [/INST] Yes, the origins of fortune cookies are a topic of much debate and research.\n",
      "\n",
      "The origins of fortune cookies are traced back to the late 19th or early 20th century in the United States, specifically in the Chinatowns of San Francisco and Los Angeles. The story goes that a Chinese immigrant named Makaeo Yamagethenye, also known as \"Mak-A-Dae,\" invented the fortune cookie in the late 1800s.\n",
      "\n",
      "According to the story, Yamagethenye was a baker who had immigrated to the United States from the Fujian province in China. He opened a bakery in San Francisco's Chinatown and began selling the cookies as a side item to accompany the Chinese food he was serving. The cookies were originally called \"fortune-telling cookies\" or \"orange-flavored, fortune-telling, Chinese, ginger-based, spiced, and roasted, and crimped, and folded, and filled with a message\" by the bakers.\n",
      "\n",
      "The story goes that Yamagethenye's cookies became popular\n",
      "======\n"
     ]
    }
   ],
   "source": [
    "layer_ids = np.arange(0, 32, 2).tolist()\n",
    "\n",
    "contrast_tokens=-8 # last {tokens} tokens are used to compute the diff in hidden_states\n",
    "alpha=0.2 # 0.1+ params\n",
    "\n",
    "dataset = load_dataset('truthful_qa', 'generation')['validation']\n",
    "questions = dataset['question']\n",
    "# or simple test\n",
    "questions = ['Where did fortune cookies originate?']\n",
    "\n",
    "for q in questions:\n",
    "    q_pos = pos_p + q\n",
    "    q_neg = neg_p + q\n",
    "\n",
    "    input = template.format(instruction=q)\n",
    "    input_pos = template.format(instruction=q_pos)\n",
    "    input_neg = template.format(instruction=q_neg)\n",
    "\n",
    "    enc = tokenizer([input, input_pos, input_neg], return_tensors='pt', padding='longest').to(model.device)\n",
    "    \n",
    "    input_ids =  enc['input_ids'][0].unsqueeze(dim=0)\n",
    "    attention_mask =  enc['attention_mask'][0].unsqueeze(dim=0)\n",
    "\n",
    "    repe_args = dict(pos_input_ids=enc['input_ids'][1].unsqueeze(dim=0),\n",
    "                     pos_attention_mask=enc['attention_mask'][1].unsqueeze(dim=0),\n",
    "                     neg_input_ids=enc['input_ids'][2].unsqueeze(dim=0),\n",
    "                     neg_attention_mask=enc['attention_mask'][2].unsqueeze(dim=0),\n",
    "                     contrast_tokens=contrast_tokens,\n",
    "                     compute_contrast=True,\n",
    "                     alpha=alpha,\n",
    "                     control_layer_ids=layer_ids)\n",
    "\n",
    "    with torch.no_grad():\n",
    "        sanity_outputs = model.generate(input_ids, \n",
    "                                 attention_mask=attention_mask, \n",
    "                                 max_new_tokens=256, \n",
    "                                 do_sample=False)\n",
    "        \n",
    "        controlled_outputs = model.generate(input_ids, \n",
    "                                 attention_mask=attention_mask, \n",
    "                                 max_new_tokens=256, \n",
    "                                 do_sample=False, \n",
    "                                 use_cache=False, # not yet supporting generation with use_cache\n",
    "                                 **repe_args)\n",
    "\n",
    "    print(\"====>Sanity output:\", tokenizer.decode(sanity_outputs[0], skip_special_tokens=True))\n",
    "    print(\"====>Controlled output:\", tokenizer.decode(controlled_outputs[0], skip_special_tokens=True))\n",
    "    print(\"======\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
